import numpy as np
import scipy.signal
from gym.spaces import Box, Discrete

import torch
import torch.nn as nn
from torch.distributions.normal import Normal 
from torch.distributions import MultivariateNormal
from torch.distributions.categorical import Categorical
import sys
from utils import LatticeStateDependentNoiseDistribution

def mlp(sizes, activation, output_activation=nn.Identity):
    """Build a fully connected network as an ``nn.Sequential``.

    Each consecutive pair in ``sizes`` becomes one ``nn.Linear`` layer
    followed by an activation module.

    Args:
        sizes: Sequence of layer widths, input width first.
        activation: Module class instantiated after every layer but the last.
        output_activation: Module class instantiated after the last layer.

    Returns:
        ``nn.Sequential`` alternating linear layers and activations; empty
        when ``sizes`` has fewer than two entries.
    """
    final_idx = len(sizes) - 2
    modules = []
    for idx, (n_in, n_out) in enumerate(zip(sizes[:-1], sizes[1:])):
        # Only the final linear layer is followed by the output activation.
        act_cls = output_activation if idx >= final_idx else activation
        modules.append(nn.Linear(n_in, n_out))
        modules.append(act_cls())
    return nn.Sequential(*modules)



class MLPLatticeActor(nn.Module):
    """Policy network whose action distribution comes from a Lattice object.

    An MLP (``mu_net``) maps observations to a latent vector. A linear head
    followed by ``tanh`` maps that latent to mean actions, and a
    ``LatticeStateDependentNoiseDistribution`` builds the distribution from
    the mean actions, the learnable ``log_std`` and the latent.

    Args:
        obs_dim: Width of the observation vector fed to ``mu_net``.
        act_dim: Width of the action vector.
        hidden_sizes: Non-empty sequence of hidden widths; the last one is
            the latent width shared by the action head and the distribution.
        activation: Module class used between the hidden layers.
        alpha: Forwarded to ``LatticeStateDependentNoiseDistribution`` and
            stored on ``self.alpha``.
    """

    def __init__(self, obs_dim, act_dim, hidden_sizes, activation, alpha):
        super().__init__()
        # ``mu_net`` ends at the LAST hidden size, so that is the width the
        # action head and the distribution must consume. Indexing with [1]
        # only worked for exactly two hidden layers.
        latent_sde_dim = hidden_sizes[-1]
        log_std_init = 0
        self.mu_net = mlp([obs_dim] + list(hidden_sizes), activation)
        self.mean_actions = nn.Linear(latent_sde_dim, act_dim)
        # Shares the same linear layer; tanh bounds the mean to (-1, 1).
        self.clipped_mean_actions = nn.Sequential(self.mean_actions, nn.Tanh())
        self.lattice = LatticeStateDependentNoiseDistribution(
            act_dim, latent_sde_dim, alpha=alpha
        )
        # Give the distribution object a handle on the (unclipped) action head.
        self.lattice.mean_actions_net = self.mean_actions
        # Learnable log-std matrix, initialised to ``log_std_init`` everywhere.
        log_std = torch.ones(latent_sde_dim, latent_sde_dim + act_dim)
        self.log_std = nn.Parameter(log_std * log_std_init, requires_grad=True)
        self.alpha = alpha

    def _distribution(self, obs):
        """Return the action distribution for ``obs``.

        The returned object is whatever ``self.lattice.distribution``
        produces; callers in this file use its ``sample`` and ``log_prob``.
        """
        latent_sde = self.mu_net(obs)
        mean_actions = self.clipped_mean_actions(latent_sde)
        return self.lattice.distribution(mean_actions, self.log_std, latent_sde)

    def forward(self, obs, act=None, mask=None):
        """Build the action distribution and optionally score given actions.

        Args:
            obs: Observations passed to ``mu_net``.
            act: Optional actions; when given, their log-probability under
                the distribution is returned as well.
            mask: Accepted for interface compatibility; not used here.

        Returns:
            Tuple ``(pi, logp_a)`` where ``logp_a`` is ``None`` when ``act``
            is ``None``.
        """
        pi = self._distribution(obs)
        logp_a = None
        if act is not None:
            logp_a = pi.log_prob(act)
        return pi, logp_a


class MLPCritic(nn.Module):
    """State-value network: an MLP with a single output unit.

    Args:
        obs_dim: Width of the observation vector.
        hidden_sizes: Sequence of hidden layer widths.
        activation: Module class used between the hidden layers.
    """

    def __init__(self, obs_dim, hidden_sizes, activation):
        super().__init__()
        layer_widths = [obs_dim, *hidden_sizes, 1]
        self.v_net = mlp(layer_widths, activation)

    def forward(self, obs):
        """Return the value estimate with the trailing size-1 axis removed."""
        values = self.v_net(obs)
        # Dropping the last axis keeps the value shape aligned with per-sample
        # scalars rather than a column vector.
        return values.squeeze(-1)



class MLPActorCritic(nn.Module):
    """Bundle of a Lattice policy (``pi``) and a value network (``v``).

    Args:
        observation_space: Space whose ``shape[0]`` gives the observation width.
        action_space: Action space; a policy is built only for ``Box`` spaces.
        hidden_sizes: Hidden layer widths shared by policy and value networks.
        activation: Module class used between hidden layers.
        alpha: Forwarded to ``MLPLatticeActor``.
    """

    def __init__(self, observation_space, action_space,
                 hidden_sizes=(64, 32), activation=nn.ReLU, alpha=1):
        super().__init__()
        obs_dim = observation_space.shape[0]

        # Policy builder depends on the action space.
        # NOTE(review): no policy is created for non-Box spaces, so ``self.pi``
        # is undefined in that case and ``step``/``act`` will fail.
        if isinstance(action_space, Box):
            self.pi = MLPLatticeActor(
                obs_dim, action_space.shape[0], hidden_sizes, activation, alpha
            )

        # Build value function.
        self.v = MLPCritic(obs_dim, hidden_sizes, activation)

    def step(self, obs, mask=None):
        """Sample an action and evaluate it, without tracking gradients.

        Args:
            obs: Observation tensor passed to the policy and value networks.
            mask: Optional tensor multiplied element-wise into the sampled
                action; ``None`` leaves the sample untouched.

        Returns:
            Tuple of NumPy arrays ``(action, value, log_prob)``.
        """
        with torch.no_grad():
            pi = self.pi._distribution(obs)
            # Draw exactly one sample; a second draw would discard the first
            # and advance the RNG for nothing.
            a = pi.sample()
            if mask is not None:
                a = torch.mul(a, mask)
            # NOTE(review): log-prob is evaluated on the masked action, as in
            # the original code - confirm this is the intended quantity.
            logp_a = pi.log_prob(a)
            v = self.v(obs)
        return a.cpu().numpy(), v.cpu().numpy(), logp_a.cpu().numpy()

    def act(self, obs, mask=None):
        """Return only the sampled action from ``step``."""
        return self.step(obs, mask)[0]